前言
在上一篇文章中,我們介紹了 LSTM(長短期記憶網路)的基本架構和三個主要的 Gate 機制,包括 Input Gate、Forget Gate 和 Output Gate。這些 Gate 讓 LSTM 能夠有效地保留重要的資訊,同時過濾掉無用的部分,解決了 RNN 在處理長期依賴問題上的不足。
現在,我們將進一步探討 LSTM 如何在實際應用中運作,通過實際的例子來說明資料如何流入與流出 LSTM。
LSTM Memory Cell
LSTM 的核心是能夠控制資訊如何在不同階段流動,其中四個主要組件包括輸入資料、Input Gate、Forget Gate 和 Output Gate。而這些組件的運作依賴於各自的權重向量與偏差值(bias),這些參數是在模型訓練時透過數據( training data )學習到的。在這裡,為了更容易理解,我們先假定權重( Weight )和偏差值( Bias )如下:
假定的權重和偏差值:
-
輸入資料 (Input Data)
- 權重設定:X1 weight = 1, X2 weight = 0, X3 weight = 0, 偏差值 (bias) = 0
解釋:這邊僅有 X1 的權重為 1,所以我們直接將 X1 作為輸入,其他輸入 (X2, X3) 的權重皆為 0,沒有影響。
-
Input Gate
- 權重設定:X1 weight = 0, X2 weight = 100, X3 weight = 0, 偏差值 (bias) = -10
解釋:在這裡,X2 的權重為 100,這是唯一對 Input Gate 有影響的輸入。如果 X2 的值為 0,最終輸出會主要受偏差值 -10 的影響。當通過 sigmoid 激活函數時,這個輸出會接近 0,表示 Input Gate 被關閉。相反,如果 X2 的值較大,最終輸出將超過偏差值,結果會是接近 1 的值,表示 Input Gate 被打開。
-
Forget Gate
- 權重設定:X1 weight = 0, X2 weight = 100, X3 weight = 0, 偏差值 (bias) = 10
解釋:Forget Gate 的偏差值為 10,這代表它通常是處於打開的狀態,允許記憶單元的資訊保留。然而,當 X2 提供一個足夠大的負值時,它會壓過這個正偏差,關閉 Forget Gate,從而選擇性地忘記記憶中的某些資訊。
-
Output Gate
- 權重設定:X1 weight = 0, X2 weight = 0, X3 weight = 100, 偏差值 (bias) = -10
解釋:Output Gate 的偏差值為 -10,這代表它通常是關閉的狀態。只有當 X3 的值足夠大時,會抵消偏差的影響,從而打開 Output Gate,使得記憶單元的資訊可以輸出到下一層神經網路。
在以上假定的設定下,我們可以清楚地看到 LSTM 中的每個 Gate 是如何依賴不同輸入向量和偏差值來決定其開啟或關閉狀態。Input Gate 和 Forget Gate 透過控制資訊的進入和遺忘來維持 LSTM 的長期記憶,而 Output Gate 則負責控制輸出的訊息流到下一層。
實際 Input 資料
接下來,為了更詳細的說明 LSTM 的 cell state 和 output 的推導,我們將使用先前所給的權重和偏差來推導輸入序列如何影響 LSTM 記憶單元的狀態和輸出。在這個過程中,LSTM 的三個 Gate(Input Gate、Forget Gate 和 Output Gate)將決定信息如何進入、被保留或輸出。我們來將以下序列輸入LSTM,來他處理的過程吧
已知輸入序列:
Time |
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
x1 |
1 |
3 |
2 |
4 |
2 |
1 |
3 |
6 |
1 |
x2 |
0 |
1 |
0 |
1 |
0 |
0 |
-1 |
1 |
0 |
x3 |
0 |
0 |
0 |
0 |
0 |
1 |
0 |
0 |
1 |
LSTM 推導過程:
每個時間步的 cell state ( c_t ) 和 output ( y_t ) 都會受到這些 Gate 的影響。現在我們逐步分析各時間點下的 cell state ( c_t ) 和 output ( y_t ) 的計算過程。
時間步 1(time = 1):
-
輸入資料 ( x1 = 1 )
- ( x2 = 0 ) (Input Gate 關閉,因為 sigmoid(-10) ≈ 0)
- Forget Gate 打開 (bias = 10, sigmoid(10) ≈ 1)
- Output Gate 關閉 (bias = -10, sigmoid(-10) ≈ 0)
推導:
- ( c_1 = 0 )(因為 Forget Gate 打開,但 cell 中沒有之前的資料)
- ( y_1 = 0 )(Output Gate 關閉)
時間步 2(time = 2):
-
輸入資料 ( x1 = 3 )
- ( x2 = 1 ) (Input Gate 打開,因為 sigmoid(100 * 1 - 10) ≈ 1)
- Forget Gate 打開 (bias = 10, sigmoid(10) ≈ 1)
- Output Gate 關閉 (bias = -10, sigmoid(-10) ≈ 0)
推導:
- ( c_2 = 0 + 3 = 3 )(Input Gate 打開,輸入3進入 cell state)
- ( y_2 = 0 )(Output Gate 關閉)
時間步 3(time = 3):
-
輸入資料 ( x1 = 2 )
- ( x2 = 0 ) (Input Gate 關閉,因為 sigmoid(-10) ≈ 0)
- Forget Gate 打開 (bias = 10, sigmoid(10) ≈ 1)
- Output Gate 關閉 (bias = -10, sigmoid(-10) ≈ 0)
推導:
- ( c_3 = 3 + 0 = 3 )(Forget Gate 打開,但沒有新資料進入 cell)
- ( y_3 = 0 )(Output Gate 關閉)
時間步 4(time = 4):
-
輸入資料 ( x1 = 4 )
- ( x2 = 1 ) (Input Gate 打開,因為 sigmoid(100 * 1 - 10) ≈ 1)
- Forget Gate 打開 (bias = 10, sigmoid(10) ≈ 1)
- Output Gate 關閉 (bias = -10, sigmoid(-10) ≈ 0)
推導:
- ( c_4 = 3 + 4 = 7 )(新資料 4 被加到 cell state 中)
- ( y_4 = 0 )(Output Gate 關閉)
時間步 5(time = 5):
-
輸入資料 ( x1 = 2 )
- ( x2 = 0 ) (Input Gate 關閉,因為 sigmoid(-10) ≈ 0)
- Forget Gate 打開 (bias = 10, sigmoid(10) ≈ 1)
- Output Gate 關閉 (bias = -10, sigmoid(-10) ≈ 0)
推導:
- ( c_5 = 7 + 0 = 7 )(Forget Gate 打開,但沒有新資料進入)
- ( y_5 = 0 )(Output Gate 關閉)
時間步 6(time = 6):
-
輸入資料 ( x1 = 1 )
- ( x2 = 0 ) (Input Gate 關閉,因為 sigmoid(-10) ≈ 0)
- Forget Gate 打開 (bias = 10, sigmoid(10) ≈ 1)
- ( x3 = 1 ) (Output Gate 打開,因為 sigmoid(100 * 1 - 10) ≈ 1)
推導:
- ( c_6 = 7 + 0 = 7 )(Forget Gate 打開,沒有新資料進入 cell)
- ( y_6 = 1 )(Output Gate 打開,輸出 cell state )
時間步 7(time = 7):
-
輸入資料 ( x1 = 3 )
- ( x2 = -1 ) (Forget Gate 關閉,因為 sigmoid(100 * -1 + 10) ≈ 0)
- ( x3 = 0 ) (Output Gate 關閉,因為 sigmoid(-10) ≈ 0)
推導:
- ( c_7 = 0 )(Forget Gate 關閉,cell state 重置)
- ( y_7 = 0 )(Output Gate 關閉)
時間步 8(time = 8):
-
輸入資料 ( x1 = 6 )
- ( x2 = 1 ) (Input Gate 打開,因為 sigmoid(100 * 1 - 10) ≈ 1)
- Forget Gate 打開 (bias = 10, sigmoid(10) ≈ 1)
- Output Gate 關閉 (bias = -10, sigmoid(-10) ≈ 0)
推導:
- ( c_8 = 0 + 6 = 6 )(新資料進入 cell state)
- ( y_8 = 0 )(Output Gate 關閉)
時間步 9(time = 9):
-
輸入資料 ( x1 = 1 )
- ( x2 = 0 ) (Input Gate 關閉,因為 sigmoid(-10) ≈ 0)
- Forget Gate 打開 (bias = 10, sigmoid(10) ≈ 1)
- ( x3 = 1 ) (Output Gate 打開,因為 sigmoid(100 * 1 - 10) ≈ 1)
推導:
- ( c_9 = 6 + 0 = 6 )(Forget Gate 打開,但沒有新資料進入 cell state)
- ( y_9 = 1 )(Output Gate 打開)
最終結果:
Time |
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
c |
0 |
0 |
3 |
3 |
7 |
7 |
0 |
6 |
6 |
y |
0 |
0 |
0 |
0 |
0 |
1 |
0 |
0 |
1 |
結論
透過這個例子,我們可以更深入地了解 LSTM 是如何運作的。每個 Gate(Input Gate、Forget Gate 和 Output Gate)各自的開關狀態,確實精準地控制了資料的流動與記憶的保存或忘記。這讓 LSTM 能夠處理具有長期依賴的序列資料,解決了 RNN 在時間步長較長時容易出現的問題。當然,這些 Gate 的效果依賴於訓練數據中學到的權重和偏差值,未來如果應用在實際的數據集上,我們可以期待 LSTM 在各種時間序列預測任務中提供出色的表現。
換句話說,LSTM 的核心價值就是幫助我們在大量且複雜的數據中找到真正重要的資訊並加以記憶,同時過濾掉那些不再有用的內容。不管是處理金融市場走勢,還是語言生成模型,LSTM 都能發揮關鍵作用。
[0]Recurrent Neural Network(Professor 李宏毅 #21-1)
[1]ML Lecture 21-1: Recurrent Neural Network (Part I)